{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"Small convolutional network: two conv/max-pool stages, then three linear layers.\n",
    "\n",
    "    conv1 takes single-channel input; fc3 produces 10 values per sample.\n",
    "    fc1 expects 16 * 5 * 5 = 400 flattened features from the second stage.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Convolutions: 1 -> 6 channels, then 6 -> 16 channels, both with 5x5 kernels.\n",
    "        self.conv1 = nn.Conv2d(1, 6, 5)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        # Affine layers (y = Wx + b): 400 -> 120 -> 84 -> 10.\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the network on batch x and return the output of fc3 (10 values per sample).\"\"\"\n",
    "        # Convolution -> ReLU -> 2x2 max pooling, twice. A pool size given as a\n",
    "        # single int is used for both spatial dimensions.\n",
    "        stage1 = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n",
    "        stage2 = F.max_pool2d(F.relu(self.conv2(stage1)), 2)\n",
    "        # Collapse everything but the batch dimension for the linear layers.\n",
    "        flat = stage2.view(-1, self.num_flat_features(stage2))\n",
    "        hidden = F.relu(self.fc1(flat))\n",
    "        hidden = F.relu(self.fc2(hidden))\n",
    "        # No activation on the last layer: raw values are returned.\n",
    "        return self.fc3(hidden)\n",
    "\n",
    "    def num_flat_features(self, x):\n",
    "        \"\"\"Return the product of all dimensions of x except the first (batch) one.\"\"\"\n",
    "        count = 1\n",
    "        for extent in x.shape[1:]:\n",
    "            count *= extent\n",
    "        return count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Net(\n",
      "  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
      "  (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
      "  (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
      ") <class 'type'>\n"
     ]
    }
   ],
   "source": [
    "# Seed the global RNG so weight initialisation and later random tensors are repeatable.\n",
    "torch.manual_seed(0)\n",
    "net = Net()\n",
    "# Show the layer summary together with the class of the instance.\n",
    "print(net, type(net))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10\n",
      "torch.Size([6, 1, 5, 5])\n"
     ]
    }
   ],
   "source": [
    "# Collect every learnable tensor of the network into a list.\n",
    "params = [p for p in net.parameters()]\n",
    "print(len(params))\n",
    "# Shape of the first parameter in the list.\n",
    "print(params[0].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义输入，前向传播"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.0851, -0.0634, -0.0953, -0.0334, -0.0953, -0.1915, -0.0252,\n",
      "          0.1188, -0.0320,  0.1387]])\n",
      "torch.FloatTensor <class 'torch.Tensor'>\n",
      "torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "# NOTE: Variable is a deprecated wrapper since PyTorch 0.4 (tensors work directly).\n",
    "# The import is kept so that any code still calling Variable(...) keeps working.\n",
    "from torch.autograd import Variable\n",
    "\n",
    "# One random single-channel 32x32 sample (batch size 1), passed to the network as a plain tensor.\n",
    "# The name shadows the builtin input(); it is kept because a later cell reads it.\n",
    "input = torch.randn(1, 1, 32, 32)\n",
    "output = net(input)\n",
    "print(output)\n",
    "print(output.type(), type(output))\n",
    "print(output.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.0851],\n",
      "        [-0.0634],\n",
      "        [-0.0953],\n",
      "        [-0.0334],\n",
      "        [-0.0953],\n",
      "        [-0.1915],\n",
      "        [-0.0252],\n",
      "        [ 0.1188],\n",
      "        [-0.0320],\n",
      "        [ 0.1387]])\n",
      "tensor([[-0.0851],\n",
      "        [-0.0634],\n",
      "        [-0.0953],\n",
      "        [-0.0334],\n",
      "        [-0.0953],\n",
      "        [-0.1915],\n",
      "        [-0.0252],\n",
      "        [ 0.1188],\n",
      "        [-0.0320],\n",
      "        [ 0.1387]])\n",
      "tensor(1.00000e-02 *\n",
      "       [-8.5070])\n",
      "torch.Size([1])\n"
     ]
    }
   ],
   "source": [
    "# clone() copies the data, so the in-place transpose below leaves `output` untouched.\n",
    "output_temp = output.clone()\n",
    "# transpose_ swaps dims 0 and 1 of output_temp in place.\n",
    "output_temp.transpose_(0, 1)\n",
    "print(output_temp)\n",
    "# Full-slice indexing: `output` keeps the same values and shape.\n",
    "output = output[:,]\n",
    "# Printing again shows that output_temp itself was changed by the in-place call.\n",
    "print(output_temp)\n",
    "print(output_temp[0])\n",
    "print(output_temp[0].size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 设置目标值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([  1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.])\n",
      "torch.Size([10])\n",
      "tensor([[  1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.]])\n",
      "torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "# Target values 1..10. The dtype is explicit because newer PyTorch releases infer an\n",
    "# integer dtype from integer bounds, which MSELoss cannot combine with a float output.\n",
    "target = torch.arange(1, 11, dtype=torch.float)\n",
    "print(target)\n",
    "print(target.size())\n",
    "# Add a leading dimension of size 1 so the shape becomes (1, 10).\n",
    "target = target.view(1, -1)\n",
    "print(target)\n",
    "print(target.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 设置损失函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'torch.nn.modules.loss.MSELoss'>\n"
     ]
    }
   ],
   "source": [
    "# Mean-squared-error loss module.\n",
    "criterion = nn.MSELoss()\n",
    "print(criterion.__class__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 计算损失函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'torch.Tensor'>\n",
      "torch.FloatTensor\n",
      "tensor(38.5867)\n"
     ]
    }
   ],
   "source": [
    "# Loss between the network output and the target.\n",
    "loss = criterion(output, target)\n",
    "# Show the Python class, the tensor type string and the value, one per line.\n",
    "print(type(loss), loss.type(), loss, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['__abs__', '__add__', '__and__', '__array__', '__array_wrap__', '__bool__', '__class__', '__deepcopy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__div__', '__doc__', '__eq__', '__float__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__iadd__', '__iand__', '__idiv__', '__ilshift__', '__imul__', '__index__', '__init__', '__init_subclass__', '__int__', '__invert__', '__ior__', '__ipow__', '__irshift__', '__isub__', '__iter__', '__itruediv__', '__ixor__', '__le__', '__len__', '__long__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__module__', '__mul__', '__ne__', '__neg__', '__new__', '__nonzero__', '__or__', '__pow__', '__radd__', '__rdiv__', '__reduce__', '__reduce_ex__', '__repr__', '__rmul__', '__rpow__', '__rshift__', '__rsub__', '__rtruediv__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__str__', '__sub__', '__subclasshook__', '__truediv__', '__weakref__', '__xor__', '_abs', '_addmv', '_addmv_', '_addr', '_addr_', '_argmax', '_argmin', '_backward_hooks', '_base', '_cast_Half', '_cast_double', '_cast_float', '_cast_int', '_cast_int16_t', '_cast_int64_t', '_cast_int8_t', '_cast_uint8_t', '_cdata', '_ceil', '_copy_ignoring_overlaps_', '_cos', '_dimI', '_dimV', '_dot', '_exp', '_fft_with_size', '_floor', '_ger', '_grad', '_grad_fn', '_indices', '_log', '_make_subclass', '_mm', '_mv', '_nnz', '_round', '_s_where', '_sin', '_sparse_mask', '_sqrt', '_standard_gamma', '_standard_gamma_grad', '_trunc', '_unique', '_values', '_version', 'abs', 'abs_', 'acos', 'acos_', 'add', 'add_', 'addbmm', 'addbmm_', 'addcdiv', 'addcdiv_', 'addcmul', 'addcmul_', 'addmm', 'addmm_', 'addmv', 'addmv_', 'addr', 'addr_', 'all', 'allclose', 'any', 'apply_', 'argmax', 'argmin', 'as_strided', 'as_strided_', 'asin', 'asin_', 'atan', 'atan2', 'atan2_', 'atan_', 'backward', 'baddbmm', 'baddbmm_', 'bernoulli', 'bernoulli_', 'bmm', 'btrifact', 'btrifact_with_info', 'btrisolve', 'byte', 'cauchy_', 'ceil', 'ceil_', 
'char', 'chunk', 'clamp', 'clamp_', 'clone', 'coalesce', 'contiguous', 'conv_tbc', 'copy_', 'cos', 'cos_', 'cosh', 'cosh_', 'cpu', 'cross', 'cuda', 'cumprod', 'cumsum', 'data', 'data_ptr', 'det', 'detach', 'detach_', 'device', 'diag', 'digamma', 'digamma_', 'dim', 'dist', 'div', 'div_', 'dot', 'double', 'dtype', 'eig', 'element_size', 'eq', 'eq_', 'equal', 'erf', 'erf_', 'erfinv', 'erfinv_', 'exp', 'exp_', 'expand', 'expand_as', 'expm1', 'expm1_', 'exponential_', 'fft', 'fill_', 'float', 'floor', 'floor_', 'fmod', 'fmod_', 'frac', 'frac_', 'gather', 'ge', 'ge_', 'gels', 'geometric_', 'geqrf', 'ger', 'gesv', 'get_device', 'grad', 'grad_fn', 'gt', 'gt_', 'half', 'histc', 'ifft', 'index', 'index_add', 'index_add_', 'index_copy', 'index_copy_', 'index_fill', 'index_fill_', 'index_put_', 'index_select', 'int', 'inverse', 'irfft', 'is_coalesced', 'is_contiguous', 'is_cuda', 'is_distributed', 'is_floating_point', 'is_leaf', 'is_nonzero', 'is_pinned', 'is_same_size', 'is_set_to', 'is_shared', 'is_signed', 'is_sparse', 'isclose', 'item', 'kthvalue', 'layout', 'le', 'le_', 'lerp', 'lerp_', 'lgamma', 'lgamma_', 'log', 'log10', 'log10_', 'log1p', 'log1p_', 'log2', 'log2_', 'log_', 'log_normal_', 'logdet', 'long', 'lt', 'lt_', 'map2_', 'map_', 'masked_copy', 'masked_copy_', 'masked_fill', 'masked_fill_', 'masked_scatter', 'masked_scatter_', 'masked_select', 'matmul', 'max', 'mean', 'median', 'min', 'mm', 'mode', 'mul', 'mul_', 'multinomial', 'mv', 'name', 'narrow', 'ndimension', 'ne', 'ne_', 'neg', 'neg_', 'nelement', 'new', 'new_empty', 'new_full', 'new_ones', 'new_tensor', 'new_zeros', 'nonzero', 'norm', 'normal_', 'numel', 'numpy', 'orgqr', 'ormqr', 'output_nr', 'permute', 'pin_memory', 'polygamma', 'polygamma_', 'potrf', 'potri', 'potrs', 'pow', 'pow_', 'prod', 'pstrf', 'put_', 'qr', 'random_', 'reciprocal', 'reciprocal_', 'record_stream', 'register_hook', 'reinforce', 'relu', 'relu_', 'remainder', 'remainder_', 'renorm', 'renorm_', 'repeat', 'requires_grad', 
'requires_grad_', 'reshape', 'resize', 'resize_', 'resize_as', 'resize_as_', 'retain_grad', 'rfft', 'round', 'round_', 'rsqrt', 'rsqrt_', 'scatter', 'scatter_', 'scatter_add', 'scatter_add_', 'select', 'set_', 'shape', 'share_memory_', 'short', 'sigmoid', 'sigmoid_', 'sign', 'sign_', 'sin', 'sin_', 'sinh', 'sinh_', 'size', 'slice', 'slogdet', 'smm', 'sort', 'split', 'split_with_sizes', 'sqrt', 'sqrt_', 'squeeze', 'squeeze_', 'sspaddmm', 'std', 'stft', 'storage', 'storage_offset', 'storage_type', 'stride', 'sub', 'sub_', 'sum', 'svd', 'symeig', 't', 't_', 'take', 'tan', 'tan_', 'tanh', 'tanh_', 'to', 'to_dense', 'tolist', 'topk', 'trace', 'transpose', 'transpose_', 'tril', 'tril_', 'triu', 'triu_', 'trtrs', 'trunc', 'trunc_', 'type', 'type_as', 'unfold', 'uniform_', 'unique', 'unsqueeze', 'unsqueeze_', 'var', 'view', 'view_as', 'where', 'zero_']\n"
     ]
    }
   ],
   "source": [
    "# dir() on a tensor yields several hundred names, so list only the public ones\n",
    "# whose name mentions gradients instead of dumping everything.\n",
    "print([name for name in dir(loss) if 'grad' in name and not name.startswith('_')])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<MseLossBackward object at 0x000002DB8F1DC518>\n",
      "<class 'MseLossBackward'>\n",
      "<AliasBackward object at 0x000002DB8F1DC518>\n",
      "<AddmmBackward object at 0x000002DB8F1DC6D8>\n"
     ]
    }
   ],
   "source": [
    "# Autograd node that produced the loss, and the class of that node.\n",
    "loss_node = loss.grad_fn\n",
    "print(loss_node)\n",
    "print(type(loss_node))\n",
    "\n",
    "# Follow the first entry of next_functions twice to look at the upstream nodes.\n",
    "parent_node = loss_node.next_functions[0][0]\n",
    "print(parent_node)\n",
    "print(parent_node.next_functions[0][0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 反向传播"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv1.bias before backward\n",
      "Parameter containing:\n",
      "tensor([-0.1806,  0.0430,  0.1913, -0.1282, -0.1779, -0.0264])\n",
      "conv1.bias.grad before backward\n",
      "None\n",
      "conv1.bias after backward\n",
      "Parameter containing:\n",
      "tensor([-0.1806,  0.0430,  0.1913, -0.1282, -0.1779, -0.0264])\n",
      "conv1.bias.grad after backward\n",
      "tensor([ 0.0923,  0.0654, -0.0013, -0.1094, -0.1548,  0.0237])\n"
     ]
    }
   ],
   "source": [
    "# Clear the gradient buffers of all parameters before backpropagating.\n",
    "net.zero_grad()\n",
    "\n",
    "conv1_bias = net.conv1.bias\n",
    "print('conv1.bias before backward', conv1_bias, sep='\\n')\n",
    "print('conv1.bias.grad before backward', conv1_bias.grad, sep='\\n')\n",
    "\n",
    "# Backpropagate the loss; the prints below show the bias and its .grad afterwards.\n",
    "loss.backward()\n",
    "print('conv1.bias after backward', conv1_bias, sep='\\n')\n",
    "print('conv1.bias.grad after backward', conv1_bias.grad, sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 权重更新（根据梯度来更新权重）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Plain gradient-descent step: weight <- weight - learning_rate * gradient.\n",
    "learning_rate = 0.01\n",
    "# no_grad keeps the in-place update itself out of autograd tracking.\n",
    "with torch.no_grad():\n",
    "    for param in net.parameters():\n",
    "        # Skip parameters that have not received a gradient.\n",
    "        if param.grad is None:\n",
    "            continue\n",
    "        param.sub_(param.grad * learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv1.bias.grad after update\n",
      "tensor([ 0.0923,  0.0654, -0.0013, -0.1094, -0.1548,  0.0237])\n"
     ]
    }
   ],
   "source": [
    "# Label and value are printed on separate lines.\n",
    "print('conv1.bias.grad after update', net.conv1.bias.grad, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv1.bias after update\n",
      "Parameter containing:\n",
      "tensor([-0.1816,  0.0423,  0.1913, -0.1271, -0.1764, -0.0266])\n"
     ]
    }
   ],
   "source": [
    "# Label and value are printed on separate lines.\n",
    "print('conv1.bias after update', net.conv1.bias, sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用内置梯度更新策略来更新权重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "loss is 37.5668449\n",
      "\n",
      "loss is 37.2013550\n",
      "\n",
      "loss is 36.7162094\n",
      "\n",
      "loss is 35.9450073\n",
      "\n",
      "loss is 34.7345657\n",
      "\n",
      "loss is 32.5138130\n",
      "\n",
      "loss is 27.9811897\n",
      "\n",
      "loss is 18.1138763\n",
      "\n",
      "loss is 2.5353034\n",
      "\n",
      "loss is 2.2518749\n"
     ]
    }
   ],
   "source": [
    "# Settings of the demo loop.\n",
    "SGD_LR = 0.001 * 6\n",
    "NUM_STEPS = 10\n",
    "\n",
    "# Hand the network's parameters to the optimizer that will update them.\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=SGD_LR)\n",
    "\n",
    "for step in range(NUM_STEPS):\n",
    "    # Gradients accumulate across backward() calls, so clear them once per step.\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    # Forward pass and loss.\n",
    "    output = net(input)\n",
    "    loss = criterion(output, target)\n",
    "    print('\\nloss is %0.7f' % loss.item())\n",
    "\n",
    "    # Backward pass fills the .grad buffers; step() applies the update to the weights.\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
